from collections import Counter

import numpy as np
from bayespy import nodes
from bayespy.inference.vmp.nodes.categorical import CategoricalMoments
from bayespy.inference import VB
import bayespy.plot as bpplt

import nltk
from nltk.corpus import stopwords
from nltk import word_tokenize
from nltk.stem.wordnet import WordNetLemmatizer

def IsWordValid(x):
    """Return True if every character of ``x`` is a letter, a digit or '-'.

    The empty string is accepted, since it contains no offending character.
    """
    return all(ch.isalpha() or ch.isdigit() or ch == '-' for ch in x)

# Stop words are only used for membership tests inside a per-token loop, so
# store them in a set (O(1) lookups) rather than NLTK's list (O(n) scans).
stop_words = frozenset(stopwords.words('english'))
min_wordlen = 4  # minimum length of a raw token kept during preprocessing
min_words = 15   # documents with fewer kept tokens than this are dropped

def posTagDocuments(src='document.tsv', dst='postag_suicide.txt'):
    """POS-tag the raw text of every document (preprocessing step 1).

    ``src`` is tab-separated with one header row. Columns 0, 1 and 3 are
    copied through unchanged. Column 2 (the raw text) is replaced by its
    NLTK tokens, written as space-separated ``word:TAG`` items using
    ``nltk.pos_tag``'s default tag set.

    This is not run on import; call it explicitly (see the commented-out call
    below).
    """
    with open(src,'r') as fr, open(dst,'w') as fw:
        fr.readline()  # skip the header row
        for line in fr:
            arr = line.strip('\r\n').split('\t')
            tagged = nltk.pos_tag(word_tokenize(arr[2]))
            s = ' '.join(word+':'+postag for word, postag in tagged)
            fw.write(arr[0]+'\t'+arr[1]+'\t'+s+'\t'+arr[3]+'\n')


#posTagDocuments()

def filterDocuments(src='postag_suicide.txt', dst='nonstop_suicide.txt'):
    """Keep the content words of each POS-tagged document (preprocessing step 2).

    Reads the ``word:TAG`` file written by ``posTagDocuments``. A token is
    kept only if all of the following hold:

    * the raw word has at least ``min_wordlen`` characters;
    * its tag starts with 'N', 'V' or 'J';
    * after lowercasing (and lemmatizing as a verb when the tag starts with
      'V'), it is not in ``stop_words`` and passes ``IsWordValid``.

    Documents keeping at least ``min_words`` tokens are written to ``dst`` as
    ``col0 \\t col1 \\t space-separated words \\t col3``.

    This is not run on import; call it explicitly (see the commented-out call
    below).
    """
    lemmatizer = WordNetLemmatizer()  # one instance, not one per token
    with open(src,'r') as fr, open(dst,'w') as fw:
        for line in fr:
            arr = line.strip('\r\n').split('\t')
            document = []
            for item in arr[2].split(' '):
                # Split on the LAST ':': the N/V/J tags never contain ':', but a
                # word can. Splitting on the first ':' would turn the rest of
                # such a word into a bogus tag. Items without ':' give word=''
                # and are skipped.
                word, _, tag = item.rpartition(':')
                if len(word) < min_wordlen:
                    continue
                if not tag.startswith(('N','V','J')):
                    continue
                word = word.lower()
                if tag.startswith('V'):
                    word = lemmatizer.lemmatize(word,'v')
                if word in stop_words or not IsWordValid(word):
                    continue
                document.append(word)
            if len(document) >= min_words:
                fw.write(arr[0]+'\t'+arr[1]+'\t'+' '.join(document)+'\t'+arr[3]+'\n')


#filterDocuments()

def _select_vocabulary(documents, n_vocabulary, n_toofreq_word, n_toofreq_phrase):
    """Return the set of mid-frequency words and phrases found in ``documents``.

    Tokens containing '_' count as phrases and all others as words. The
    ``n_toofreq_word`` most frequent words and the ``n_toofreq_phrase`` most
    frequent phrases are discarded as too common. The next ``n_vocabulary``
    of each kind are kept. Equal counts keep first-seen order.

    Args:
        documents: iterable of token lists (consumed once).
    """
    word2count = Counter()
    phrase2count = Counter()
    for doc in documents:
        for token in doc:
            (phrase2count if '_' in token else word2count)[token] += 1
    # most_common() sorts by count with a stable sort, so ties stay in
    # first-seen order.
    words = [w for w, _ in word2count.most_common()[n_toofreq_word:n_toofreq_word+n_vocabulary]]
    phrases = [p for p, _ in phrase2count.most_common()[n_toofreq_phrase:n_toofreq_phrase+n_vocabulary]]
    return set(words + phrases)


def _index_documents(documents):
    """Flatten token lists into parallel id arrays, numbering terms by first appearance.

    Returns:
        ``(word_documents, corpus, vocabulary)``, where:

        * ``word_documents[i]`` is the index of the document containing token ``i``;
        * ``corpus[i]`` is that token's term id;
        * ``vocabulary[term_id]`` is the term string.
    """
    term2id = {}
    word_documents = []
    corpus = []
    for id_doc, tokens in enumerate(documents):
        for token in tokens:
            corpus.append(term2id.setdefault(token, len(term2id)))
            word_documents.append(id_doc)
    return np.array(word_documents), np.array(corpus), list(term2id)


def _write_top_terms(prefix, vocabulary, alpha, n_topics, n_top):
    """Write the ``n_top`` highest-scoring words and phrases of every topic.

    ``alpha[t][i]`` is the score of term ``vocabulary[i]`` in topic ``t``
    (callers pass the p_word Dirichlet parameters). Words go to
    ``<prefix>word.txt`` and '_'-joined phrases to ``<prefix>phrase.txt``.
    Each topic begins with a ``#<topic>`` line, followed by
    `` <term> <score rounded to 4 places>`` lines in descending score order.
    """
    word_ids = [i for i, term in enumerate(vocabulary) if '_' not in term]
    phrase_ids = [i for i, term in enumerate(vocabulary) if '_' in term]
    with open(prefix+'word.txt','w') as fw_word, open(prefix+'phrase.txt','w') as fw_phrase:
        for id_topic in range(n_topics):
            scores = alpha[id_topic]
            for fw, ids in ((fw_word, word_ids), (fw_phrase, phrase_ids)):
                fw.write('#'+str(id_topic)+'\n')
                # Stable sort: equal scores keep vocabulary order.
                for i in sorted(ids, key=scores.__getitem__, reverse=True)[:n_top]:
                    fw.write(' '+vocabulary[i]+' '+str(np.round(scores[i],4))+'\n')


def _load_post_response_pairs(n_vocabulary, n_toofreq_word, n_toofreq_phrase):
    """Read ``nonstop_suicide.txt`` and return filtered ``(post, responses)`` pairs.

    Filtering steps:

    1. Rows are grouped by column 0; column 1 is 'Y' for a post and 'N' for a
       response.
    2. Only groups with at least one post and one response are used, and only
       their first post.
    3. Post and response vocabularies are chosen separately by
       ``_select_vocabulary``.
    4. Each document is reduced to its in-vocabulary tokens and kept if at
       least ``min_words`` remain.
    5. A group survives if its post and at least one response are kept.

    Prints group counts and the first group's posts before and after filtering.

    Returns:
        list of ``(post_tokens, [response_tokens, ...])`` in file order.

    Raises:
        ValueError: if the file has no rows or no group survives filtering.
    """
    tid2raw_word_documents = {}
    with open('nonstop_suicide.txt','r') as fr:
        for line in fr:
            arr = line.strip('\r\n').split('\t')
            posts, responses = tid2raw_word_documents.setdefault(arr[0], [[], []])
            if arr[1] == 'Y':
                posts.append(arr[2].split(' ')+arr[3].split(' '))
            elif arr[1] == 'N':
                responses.append(arr[2].split(' ')+arr[3].split(' '))
    if not tid2raw_word_documents:
        raise ValueError('nonstop_suicide.txt contains no documents')
    print('#posts',len(tid2raw_word_documents))
    print('posts[0]',next(iter(tid2raw_word_documents.values()))[0])

    threads = [(posts[0], responses)
               for posts, responses in tid2raw_word_documents.values()
               if posts and responses]
    wordsetP = _select_vocabulary((post for post, _ in threads),
                                  n_vocabulary, n_toofreq_word, n_toofreq_phrase)
    wordsetR = _select_vocabulary((response for _, responses in threads for response in responses),
                                  n_vocabulary, n_toofreq_word, n_toofreq_phrase)

    pairs = []
    for post, responses in threads:
        kept_post = [w for w in post if w in wordsetP]
        if len(kept_post) < min_words:
            continue
        kept_responses = []
        for response in responses:
            kept_response = [w for w in response if w in wordsetR]
            if len(kept_response) >= min_words:
                kept_responses.append(kept_response)
        if kept_responses:
            pairs.append((kept_post, kept_responses))
    if not pairs:
        raise ValueError('no thread keeps at least %d in-vocabulary tokens in its post and a response' % min_words)
    print('#posts',len(pairs))
    print('posts[0]',[pairs[0][0]])
    return pairs


def separateLDA(post_or_response,n_topics):
    """Fit an LDA topic model to the posts alone or to the responses alone.

    Reads ``nonstop_suicide.txt``, which is tab-separated:

    * column 1 is 'Y' for a post and 'N' for a response;
    * columns 2 and 3 hold space-separated tokens, and tokens containing '_'
      are phrases.

    Documents are reduced to the vocabulary chosen by ``_select_vocabulary``
    and kept if at least ``min_words`` tokens remain. LDA is then fitted with
    BayesPy stochastic variational inference on random mini-batches of tokens.

    Args:
        post_or_response: 'P' to model posts, 'R' to model responses.
        n_topics: number of topics.

    Writes the ``topics_<n_topics>_post_*`` or ``topics_<n_topics>_response_*``
    word and phrase files via ``_write_top_terms``.

    Raises:
        ValueError: if no document of the requested kind survives filtering.
    """
    n_vocabulary = 1500
    n_iters = 1500
    n_toofreq_word = 150
    n_toofreq_phrase = 30
    n_top = 50
    subset_size = 1000

    flag = {'P': 'Y', 'R': 'N'}.get(post_or_response)
    raw_word_documents = []
    with open('nonstop_suicide.txt','r') as fr:
        for line in fr:
            arr = line.strip('\r\n').split('\t')
            if arr[1] == flag:
                raw_word_documents.append(arr[2].split(' ')+arr[3].split(' '))
    if not raw_word_documents:
        raise ValueError('no documents of kind %r in nonstop_suicide.txt' % (post_or_response,))
    print('#doc',len(raw_word_documents))
    print('doc[0]',raw_word_documents[0])

    wordset = _select_vocabulary(raw_word_documents, n_vocabulary, n_toofreq_word, n_toofreq_phrase)
    documents = []
    for doc in raw_word_documents:
        tokens = [w for w in doc if w in wordset]
        if len(tokens) >= min_words:
            documents.append(tokens)
    if not documents:
        raise ValueError('no document keeps at least %d in-vocabulary tokens' % min_words)
    n_documents = len(documents)
    print('#doc',n_documents)
    print('doc[0]',documents[0])

    word_documents, corpus, vocabulary = _index_documents(documents)
    n_words = len(corpus)
    n_vocabulary = len(vocabulary)
    print('#word',n_words)
    print('vocab',n_vocabulary)

    ### Stochastic Variational Inference for LDA ###

    _p_topic_ = nodes.Dirichlet(np.ones(n_topics),plates=(n_documents,),name='p_topic')
    _p_word_ = nodes.Dirichlet(np.ones(n_vocabulary),plates=(n_topics,),name='p_word')

    # The initial indices are overwritten by every mini-batch. np.resize only
    # guarantees exactly subset_size entries, even for tiny corpora.
    _document_indices_ = nodes.Constant(CategoricalMoments(n_documents),np.resize(word_documents,subset_size),name='document_indices')
    # Each mini-batch stands in for all n_words tokens, so its likelihood is
    # scaled by the exact ratio. An int() cast would under-weight the data and
    # give 0 when the corpus is smaller than one mini-batch.
    _topics_ = nodes.Categorical(nodes.Gate(_document_indices_,_p_topic_),plates=(subset_size,),plates_multiplier=(n_words/subset_size,),name='topics')
    _words_ = nodes.Categorical(nodes.Gate(_topics_,_p_word_),name='words')

    _p_topic_.initialize_from_random()
    _p_word_.initialize_from_random()

    Q = VB(_words_,_topics_,_p_word_,_p_topic_,_document_indices_)
    Q.ignore_bound_checks = True  # mini-batch updates are noisy
    delay = 1
    forgetting_rate = 0.7

    for n in range(n_iters):
        # Observe a random mini-batch (sampled with replacement)
        subset = np.random.choice(n_words,subset_size)
        Q['words'].observe(corpus[subset])
        Q['document_indices'].set_value(word_documents[subset])
        # Learn intermediate variables
        Q.update('topics')
        # Decaying step length (n+delay)^(-forgetting_rate)
        step = (n+delay)**(-forgetting_rate)
        # Stochastic gradient for the global variables
        Q.gradient_step('p_topic','p_word',scale=step)

    prefix = 'topics_'+str(n_topics)+'_'+{'P': 'post_', 'R': 'response_'}.get(post_or_response, '')
    _write_top_terms(prefix, vocabulary, Q['p_word'].get_parameters()[0], n_topics, n_top)


#separateLDA('P',5)
#separateLDA('R',5)
#separateLDA('P',20)
#separateLDA('R',20)


def jointLDA(n_topicsP,n_topicsR):
    """Fit separate LDA models to posts and to their responses in one SVI loop.

    Threads are loaded and filtered by ``_load_post_response_pairs``; every
    kept post is one post document and every kept response is one response
    document. The two models share no variables; they are only updated in
    the same loop.

    Args:
        n_topicsP: number of post topics.
        n_topicsR: number of response topics.

    Writes ``joint_topics_<P>_<R>_{post,response}_{word,phrase}.txt`` via
    ``_write_top_terms``.

    Raises:
        ValueError: if no thread survives filtering.
    """
    n_vocabulary = 1500
    n_iters = 1500
    n_toofreq_word = 150
    n_toofreq_phrase = 30
    n_top = 50
    subset_size = 1000

    # post vs response
    pairs = _load_post_response_pairs(n_vocabulary, n_toofreq_word, n_toofreq_phrase)
    posts = [post for post, _ in pairs]
    responses = [response for _, thread_responses in pairs for response in thread_responses]

    word_documentsP, corpusP, vocabularyP = _index_documents(posts)
    word_documentsR, corpusR, vocabularyR = _index_documents(responses)
    n_documentsP = len(posts)
    n_documentsR = len(responses)
    n_wordsP = len(corpusP)
    n_vocabularyP = len(vocabularyP)
    n_wordsR = len(corpusR)
    n_vocabularyR = len(vocabularyR)

    print('#wordP',n_wordsP)
    print('vocabP',n_vocabularyP)
    print('#wordR',n_wordsR)
    print('vocabR',n_vocabularyR)

    ### Stochastic Variational Inference for LDA ###
    # Mini-batch likelihoods are scaled by the exact n_words/subset_size
    # ratio. Initial indices use np.resize to get exactly subset_size entries.

    _p_topicP_ = nodes.Dirichlet(np.ones(n_topicsP),plates=(n_documentsP,),name='p_topicP')
    _p_wordP_ = nodes.Dirichlet(np.ones(n_vocabularyP),plates=(n_topicsP,),name='p_wordP')

    _document_indicesP_ = nodes.Constant(CategoricalMoments(n_documentsP),np.resize(word_documentsP,subset_size),name='document_indicesP')
    _topicsP_ = nodes.Categorical(nodes.Gate(_document_indicesP_,_p_topicP_),plates=(subset_size,),plates_multiplier=(n_wordsP/subset_size,),name='topicsP')
    _wordsP_ = nodes.Categorical(nodes.Gate(_topicsP_,_p_wordP_),name='wordsP')

    _p_topicR_ = nodes.Dirichlet(np.ones(n_topicsR),plates=(n_documentsR,),name='p_topicR')
    _p_wordR_ = nodes.Dirichlet(np.ones(n_vocabularyR),plates=(n_topicsR,),name='p_wordR')

    _document_indicesR_ = nodes.Constant(CategoricalMoments(n_documentsR),np.resize(word_documentsR,subset_size),name='document_indicesR')
    _topicsR_ = nodes.Categorical(nodes.Gate(_document_indicesR_,_p_topicR_),plates=(subset_size,),plates_multiplier=(n_wordsR/subset_size,),name='topicsR')
    _wordsR_ = nodes.Categorical(nodes.Gate(_topicsR_,_p_wordR_),name='wordsR')

    _p_topicP_.initialize_from_random()
    _p_wordP_.initialize_from_random()

    _p_topicR_.initialize_from_random()
    _p_wordR_.initialize_from_random()

    Q = VB(_wordsP_,_topicsP_,_p_wordP_,_p_topicP_,_document_indicesP_,
            _wordsR_,_topicsR_,_p_wordR_,_p_topicR_,_document_indicesR_)
    Q.ignore_bound_checks = True  # mini-batch updates are noisy
    delay = 1
    forgetting_rate = 0.7

    for n in range(n_iters):
        # Observe a random mini-batch (sampled with replacement)
        subsetP = np.random.choice(n_wordsP,subset_size)
        subsetR = np.random.choice(n_wordsR,subset_size)
        Q['wordsP'].observe(corpusP[subsetP])
        Q['wordsR'].observe(corpusR[subsetR])
        Q['document_indicesP'].set_value(word_documentsP[subsetP])
        Q['document_indicesR'].set_value(word_documentsR[subsetR])
        # Learn intermediate variables
        Q.update('topicsP')
        Q.update('topicsR')
        # Decaying step length (n+delay)^(-forgetting_rate)
        step = (n+delay)**(-forgetting_rate)
        # Stochastic gradient for the global variables
        Q.gradient_step('p_topicP','p_topicR','p_wordP','p_wordR',scale=step)

    prefix = 'joint_topics_'+str(n_topicsP)+'_'+str(n_topicsR)+'_'
    _write_top_terms(prefix+'post_', vocabularyP, Q['p_wordP'].get_parameters()[0], n_topicsP, n_top)
    _write_top_terms(prefix+'response_', vocabularyR, Q['p_wordR'].get_parameters()[0], n_topicsR, n_top)


#jointLDA(5,10)
#jointLDA(20,20)


def pairLDA(n_topicsP,n_topicsR):
    """Like ``jointLDA``, plus a third model of each thread's pooled responses.

    The extra 'RinP' model treats all kept responses of a thread as one
    document, indexed by the thread's post. It has its own topic proportions,
    ``p_topicRinP`` (one row per post), but shares the response topic-word
    distributions ``p_wordR``.

    Args:
        n_topicsP: number of post topics.
        n_topicsR: number of response topics (also used by the RinP model).

    Writes ``pair_topics_<P>_<R>_{post,response}_{word,phrase}.txt`` via
    ``_write_top_terms``.

    Raises:
        ValueError: if no thread survives filtering.
    """
    n_vocabulary = 1500
    n_iters = 1500
    n_toofreq_word = 150
    n_toofreq_phrase = 30
    n_top = 50
    subset_size = 1000

    # post vs response
    pairs = _load_post_response_pairs(n_vocabulary, n_toofreq_word, n_toofreq_phrase)
    posts = [post for post, _ in pairs]
    responses = [response for _, thread_responses in pairs for response in thread_responses]
    # Index of the post each response belongs to.
    parent_post = np.array([id_post
                            for id_post, (_, thread_responses) in enumerate(pairs)
                            for _response in thread_responses])

    word_documentsP, corpusP, vocabularyP = _index_documents(posts)
    word_documentsR, corpusR, vocabularyR = _index_documents(responses)
    # RinP sees exactly the response tokens (same term ids as corpusR), but
    # each token is attributed to its thread's post instead of its response.
    word_documentsRinP = parent_post[word_documentsR]
    n_documentsP = len(posts)
    n_documentsR = len(responses)
    n_wordsP = len(corpusP)
    n_vocabularyP = len(vocabularyP)
    n_wordsR = len(corpusR)
    n_vocabularyR = len(vocabularyR)
    n_wordsRinP = len(word_documentsRinP)

    print('#wordP',n_wordsP)
    print('vocabP',n_vocabularyP)
    print('#wordR',n_wordsR)
    print('vocabR',n_vocabularyR)
    print('#wordRinP',n_wordsRinP)

    ### Stochastic Variational Inference for LDA ###
    # Mini-batch likelihoods are scaled by the exact n_words/subset_size
    # ratio. Initial indices use np.resize to get exactly subset_size entries.

    _p_topicP_ = nodes.Dirichlet(np.ones(n_topicsP),plates=(n_documentsP,),name='p_topicP')
    _p_wordP_ = nodes.Dirichlet(np.ones(n_vocabularyP),plates=(n_topicsP,),name='p_wordP')

    _document_indicesP_ = nodes.Constant(CategoricalMoments(n_documentsP),np.resize(word_documentsP,subset_size),name='document_indicesP')
    _topicsP_ = nodes.Categorical(nodes.Gate(_document_indicesP_,_p_topicP_),plates=(subset_size,),plates_multiplier=(n_wordsP/subset_size,),name='topicsP')
    _wordsP_ = nodes.Categorical(nodes.Gate(_topicsP_,_p_wordP_),name='wordsP')

    _p_topicR_ = nodes.Dirichlet(np.ones(n_topicsR),plates=(n_documentsR,),name='p_topicR')
    _p_wordR_ = nodes.Dirichlet(np.ones(n_vocabularyR),plates=(n_topicsR,),name='p_wordR')

    _document_indicesR_ = nodes.Constant(CategoricalMoments(n_documentsR),np.resize(word_documentsR,subset_size),name='document_indicesR')
    _topicsR_ = nodes.Categorical(nodes.Gate(_document_indicesR_,_p_topicR_),plates=(subset_size,),plates_multiplier=(n_wordsR/subset_size,),name='topicsR')
    _wordsR_ = nodes.Categorical(nodes.Gate(_topicsR_,_p_wordR_),name='wordsR')

    _p_topicRinP_ = nodes.Dirichlet(np.ones(n_topicsR),plates=(n_documentsP,),name='p_topicRinP')
    _document_indicesRinP_ = nodes.Constant(CategoricalMoments(n_documentsP),np.resize(word_documentsRinP,subset_size),name='document_indicesRinP')
    _topicsRinP_ = nodes.Categorical(nodes.Gate(_document_indicesRinP_,_p_topicRinP_),plates=(subset_size,),plates_multiplier=(n_wordsRinP/subset_size,),name='topicsRinP')
    _wordsRinP_ = nodes.Categorical(nodes.Gate(_topicsRinP_,_p_wordR_),name='wordsRinP')

    _p_topicP_.initialize_from_random()
    _p_wordP_.initialize_from_random()

    _p_topicR_.initialize_from_random()
    _p_wordR_.initialize_from_random()

    _p_topicRinP_.initialize_from_random()

    Q = VB(_wordsP_,_topicsP_,_p_wordP_,_p_topicP_,_document_indicesP_,
            _wordsR_,_topicsR_,_p_wordR_,_p_topicR_,_document_indicesR_,
            _wordsRinP_,_topicsRinP_,_p_topicRinP_,_document_indicesRinP_)
    Q.ignore_bound_checks = True  # mini-batch updates are noisy
    delay = 1
    forgetting_rate = 0.7

    for n in range(n_iters):
        # Observe a random mini-batch (sampled with replacement)
        subsetP = np.random.choice(n_wordsP,subset_size)
        subsetR = np.random.choice(n_wordsR,subset_size)
        subsetRinP = np.random.choice(n_wordsRinP,subset_size)
        Q['wordsP'].observe(corpusP[subsetP])
        Q['wordsR'].observe(corpusR[subsetR])
        Q['wordsRinP'].observe(corpusR[subsetRinP])
        Q['document_indicesP'].set_value(word_documentsP[subsetP])
        Q['document_indicesR'].set_value(word_documentsR[subsetR])
        Q['document_indicesRinP'].set_value(word_documentsRinP[subsetRinP])
        # Learn intermediate variables
        Q.update('topicsP')
        Q.update('topicsR')
        Q.update('topicsRinP')
        # Decaying step length (n+delay)^(-forgetting_rate)
        step = (n+delay)**(-forgetting_rate)
        # Stochastic gradient for the global variables
        Q.gradient_step('p_topicP','p_topicR','p_wordP','p_wordR','p_topicRinP',scale=step)

    prefix = 'pair_topics_'+str(n_topicsP)+'_'+str(n_topicsR)+'_'
    _write_top_terms(prefix+'post_', vocabularyP, Q['p_wordP'].get_parameters()[0], n_topicsP, n_top)
    _write_top_terms(prefix+'response_', vocabularyR, Q['p_wordR'].get_parameters()[0], n_topicsR, n_top)


#pairLDA(5,10)
#pairLDA(20,20)



